import torch
import lzs_cuda_test


# Smoke test for the `lzs_cuda_test` extension.
# It runs `lzs_test` on a random tensor placed on CUDA device 0. It then
# reports the shape and device of the input and output, both as produced and
# after copying them to the CPU. Finally it counts how many elements of input
# and output compare equal, on each device.
# NOTE(review): requires CUDA device 0 and the compiled `lzs_cuda_test`
# extension. What `lzs_test` computes is not visible here.

# Random 8x256 input from torch.randn, moved to CUDA device 0.
i = torch.randn(8, 256).to("cuda:0")
o = lzs_cuda_test.lzs_test(i)

print(
    "The shape of Tensor `i` is: ", i.shape,
    "; The device of Tensor `i` is: ", i.device,
)
print(
    "The shape of Tensor `o` is: ", o.shape,
    "; The device of Tensor `o` is: ", o.device,
)

# Host copies of both tensors.
ii = i.to("cpu")
oo = o.to("cpu")
print(
    "The shape of Tensor `ii` is: ", ii.shape,
    "; The device of Tensor `ii` is: ", ii.device,
)
print(
    "The shape of Tensor `oo` is: ", oo.shape,
    "; The device of Tensor `oo` is: ", oo.device,
)

# Elementwise equality counts. The CPU count is expected to match the CUDA
# count, since `ii`/`oo` are plain copies.
# NOTE(review): assumes `o` has the same (or a broadcast-compatible) shape as
# `i`; otherwise `torch.eq` raises. TODO: confirm against `lzs_test`.
print("The number of equal CUDA elements is: ", torch.eq(i, o).sum().item())
print("The number of equal CPU elements is: ", torch.eq(ii, oo).sum().item())

